{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp data.load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "\n",
    "from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind\n",
    "_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 4\n",
    "letters = list(string.ascii_lowercase)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DataLoader helpers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "fastai includes a replacement for PyTorch's *DataLoader* which is largely API-compatible, and adds a lot of useful functionality and flexibility. Before we look at the class, there are a couple of helpers we'll need to define."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _wif(worker_id):\n",
    "    # Worker-side initialiser: `_FakeLoader` exposes it as `worker_init_fn`\n",
    "    set_num_threads(1)\n",
    "    worker = get_worker_info()\n",
    "    dl = worker.dataset.d  # presumably `worker.dataset` is the `_FakeLoader`, whose `.d` is the wrapped loader\n",
    "    # Record how many workers exist and which one this is; `DataLoader.sample` reads both attributes\n",
    "    dl.num_workers = worker.num_workers\n",
    "    dl.offs = worker.id\n",
    "    set_seed(worker.seed)\n",
    "    dl.wif()\n",
    "\n",
    "class _FakeLoader:\n",
    "    # Stand-in handed to PyTorch's private loader iterators (`_loaders`). The class attributes below are\n",
    "    # presumably the ones those iterators read; the batches themselves come from the wrapped loader `d`.\n",
    "    _IterableDataset_len_called = None\n",
    "    _auto_collation = False\n",
    "    collate_fn = noops\n",
    "    drop_last = False\n",
    "    _index_sampler = Inf.count\n",
    "    generator = None\n",
    "    prefetch_factor = 2\n",
    "    dataset_kind = _dataset_kind = _DatasetKind.Iterable\n",
    "    def __init__(self, d, pin_memory, num_workers, timeout):\n",
    "        # This object serves as its own dataset; `default` also points at `d`\n",
    "        self.dataset = self\n",
    "        self.default = d\n",
    "        self.worker_init_fn = _wif\n",
    "        store_attr('d,pin_memory,num_workers,timeout')\n",
    "\n",
    "    def __iter__(self):\n",
    "        # Each pass asks `d` for its samples and turns them into batches\n",
    "        samples = self.d.sample()\n",
    "        return iter(self.d.create_batches(samples))\n",
    "\n",
    "    @property\n",
    "    def multiprocessing_context(self):\n",
    "        # The `multiprocessing` module when workers are requested, otherwise `None`\n",
    "        return multiprocessing if self.num_workers>0 else None\n",
    "\n",
    "    @contextmanager\n",
    "    def no_multiproc(self):\n",
    "        # Temporarily force `num_workers=0`, yielding `d`; the previous value is restored on exit\n",
    "        saved = self.num_workers\n",
    "        try:\n",
    "            self.num_workers = 0\n",
    "            yield self.d\n",
    "        finally:\n",
    "            self.num_workers = saved\n",
    "\n",
    "# Types that `fa_collate`/`fa_convert` pass straight to `default_collate`/`default_convert` without recursing\n",
    "_collate_types = (ndarray, Tensor, typing.Mapping, str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def fa_collate(t):\n",
    "    \"A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s\"\n",
    "    sample = t[0]\n",
    "    # Recurse field-by-field only for sequences outside `_collate_types`, rebuilding the sample's container type\n",
    "    if isinstance(sample, Sequence) and not isinstance(sample, _collate_types):\n",
    "        return type(sample)([fa_collate(field) for field in zip(*t)])\n",
    "    # Everything else is delegated to `default_collate`\n",
    "    return default_collate(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#e.g. x is int, y is tuple\n",
    "t = [(1,(2,3)),(1,(2,3))]\n",
    "# Values agree with `default_collate`, and the nested tuple is still a tuple after collation\n",
    "test_eq(fa_collate(t), default_collate(t))\n",
    "test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])\n",
    "\n",
    "# Same check with one more level of nesting\n",
    "t = [(1,(2,(3,4))),(1,(2,(3,4)))]\n",
    "test_eq(fa_collate(t), default_collate(t))\n",
    "test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])\n",
    "test_eq(L(fa_collate(t)[1]).map(type), [Tensor,tuple])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def fa_convert(t):\n",
    "    \"A replacement for PyTorch `default_convert` which maintains types and handles `Sequence`s\"\n",
    "    # Recurse element-by-element only for sequences outside `_collate_types`, keeping the container type\n",
    "    if isinstance(t, Sequence) and not isinstance(t, _collate_types):\n",
    "        return type(t)([fa_convert(elem) for elem in t])\n",
    "    # Everything else is delegated to `default_convert`\n",
    "    return default_convert(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t0 = array([1,2])\n",
    "t = [t0,(t0,t0)]\n",
    "\n",
    "# Values agree with `default_convert`, and the inner tuple is still a tuple afterwards\n",
    "test_eq(fa_convert(t), default_convert(t))\n",
    "test_eq(L(fa_convert(t)).map(type), [Tensor,tuple])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SkipItemException(Exception):\n",
    "    \"Raised to notify `DataLoader` to skip an item\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"SkipItemException\" class=\"doc_header\"><code>class</code> <code>SkipItemException</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>SkipItemException</code>() :: `Exception`\n",
       "\n",
       "Raised to notify [`DataLoader`](/data.load.html#DataLoader) to skip an item"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(SkipItemException, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DataLoader -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@log_args(but='dataset,wif,create_batch,create_batches,create_item,retain,get_idxs,sample,shuffle_fn,do_batch')\n",
    "@funcs_kwargs\n",
    "class DataLoader(GetAttr):\n",
    "    # Hooks that, unless overridden, return their first argument unchanged\n",
    "    _noop_methods = 'wif before_iter after_item before_batch after_batch after_iter'.split()\n",
    "    for o in _noop_methods: exec(f\"def {o}(self, x=None, *args, **kwargs): return x\")\n",
    "    # Overridable methods; `new` copies each of them onto the loader it creates (every name listed once)\n",
    "    _methods = _noop_methods + ('create_batches create_item create_batch retain '\n",
    "                                'get_idxs sample shuffle_fn do_batch').split()\n",
    "    _default = 'dataset'\n",
    "    def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,\n",
    "                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, **kwargs):\n",
    "        if batch_size is not None: bs = batch_size # PyTorch compatibility\n",
    "        # Without a batch size this loader does no batching, so there is no partial batch to drop\n",
    "        assert not (bs is None and drop_last), \"`drop_last=True` requires a batch size (`bs`)\"\n",
    "        # Datasets exposing `__getitem__` are read by index; anything else is consumed through its iterator\n",
    "        if indexed is None: indexed = dataset is not None and hasattr(dataset,'__getitem__')\n",
    "        if n is None:\n",
    "            try: n = len(dataset)\n",
    "            except TypeError: pass # `dataset` has no usable length: deliberately leave `n` as None\n",
    "        store_attr('dataset,bs,shuffle,drop_last,indexed,n,pin_memory,timeout,device')\n",
    "        # `num_workers`/`offs` describe the current process here; `_wif` overwrites them inside worker processes.\n",
    "        # The requested worker count itself is kept on `fake_l`.\n",
    "        self.rng,self.num_workers,self.offs = random.Random(random.randint(0,2**32-1)),1,0\n",
    "        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout)\n",
    "\n",
    "    def __len__(self):\n",
    "        # Number of batches per epoch, or `n` itself when no batch size is set\n",
    "        if self.n is None: raise TypeError(f\"object of type '{type(self).__name__}' has no len() because `n` is unknown\")\n",
    "        if self.bs is None: return self.n\n",
    "        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)\n",
    "\n",
    "    def get_idxs(self):\n",
    "        # `Inf.count` for indexed datasets, `Inf.nones` otherwise; `create_item` treats a `None` sample as 'use the iterator'\n",
    "        idxs = Inf.count if self.indexed else Inf.nones\n",
    "        if self.n is not None: idxs = list(itertools.islice(idxs, self.n))\n",
    "        if self.shuffle: idxs = self.shuffle_fn(idxs)\n",
    "        return idxs\n",
    "\n",
    "    def sample(self):\n",
    "        # Batch number `i//bs` is assigned to worker `(i//bs) % num_workers`; keep only this worker's share\n",
    "        return (b for i,b in enumerate(self.__idxs) if i//(self.bs or 1)%self.num_workers==self.offs)\n",
    "\n",
    "    def __iter__(self):\n",
    "        self.randomize()\n",
    "        self.before_iter()\n",
    "        self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)\n",
    "        # `_loaders` is (multi-process, single-process): index 1 is chosen when no workers are requested\n",
    "        for b in _loaders[self.fake_l.num_workers==0](self.fake_l):\n",
    "            if self.device is not None: b = to_device(b, self.device)\n",
    "            yield self.after_batch(b)\n",
    "        self.after_iter()\n",
    "        if hasattr(self, 'it'): del(self.it)\n",
    "\n",
    "    def create_batches(self, samps):\n",
    "        # `self.it` backs `create_item` for samples that are `None`\n",
    "        self.it = iter(self.dataset) if self.dataset is not None else None\n",
    "        # `do_item` returns None for skipped items; those are filtered out before batching\n",
    "        res = filter(lambda o:o is not None, map(self.do_item, samps))\n",
    "        yield from map(self.do_batch, self.chunkify(res))\n",
    "\n",
    "    def new(self, dataset=None, cls=None, **kwargs):\n",
    "        if dataset is None: dataset = self.dataset\n",
    "        if cls is None: cls = type(self)\n",
    "        cur_kwargs = dict(dataset=dataset, num_workers=self.fake_l.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,\n",
    "                          bs=self.bs, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)\n",
    "        for n in self._methods: cur_kwargs[n] = getattr(self, n)\n",
    "        # Caller-supplied `kwargs` are merged last (presumably taking precedence; depends on `merge`)\n",
    "        return cls(**merge(cur_kwargs, kwargs))\n",
    "\n",
    "    @property\n",
    "    def prebatched(self): return self.bs is None\n",
    "    def do_item(self, s):\n",
    "        try: return self.after_item(self.create_item(s))\n",
    "        except SkipItemException: return None\n",
    "    def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)\n",
    "    def shuffle_fn(self, idxs): return self.rng.sample(idxs, len(idxs))\n",
    "    def randomize(self): self.rng = random.Random(self.rng.randint(0,2**32-1))\n",
    "    def retain(self, res, b):  return retain_types(res, b[0] if is_listy(b) else b)\n",
    "    def create_item(self, s):  return next(self.it) if s is None else self.dataset[s]\n",
    "    def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)\n",
    "    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)\n",
    "    def to(self, device): self.device = device\n",
    "    def one_batch(self):\n",
    "        if self.n is not None and len(self)==0: raise ValueError('This DataLoader does not contain any batches')\n",
    "        # Fetch with `num_workers` forced to 0 for the duration of the call\n",
    "        with self.fake_l.no_multiproc(): res = first(self)\n",
    "        # The epoch is presumably abandoned after one batch, so `__iter__` never reaches its own cleanup of `it`\n",
    "        if hasattr(self, 'it'): delattr(self, 'it')\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Docstrings for `DataLoader` and its methods are attached here rather than inline in the class body\n",
    "add_docs(DataLoader, \"API compatible with PyTorch DataLoader, with a lot more callbacks and flexibility\",\n",
    "         get_idxs       = \"Return a list of indices to reference the dataset. Calls `shuffle_fn` internally if `shuffle=True`.\",\n",
    "         sample         = \"Return a generator over the indices from `get_idxs` that belong to the current worker.\",\n",
    "         create_batches = \"Takes output of `sample` as input, and returns batches of data. Does not apply `after_batch`.\",\n",
    "         new            = \"Create a new `DataLoader` with given arguments keeping remaining arguments same as original `DataLoader`.\",\n",
    "         prebatched     = \"Check if `bs` is None.\",\n",
    "         do_item        = \"Combines `after_item` and `create_item` to get an item from dataset by providing index as input.\",\n",
    "         chunkify       = \"Used by `create_batches` to turn generator of items (`b`) into batches.\",\n",
    "         shuffle_fn     = \"Returns a random permutation of `idxs`.\",\n",
    "         randomize      = \"Sets `DataLoader` random number generator state.\",\n",
    "         retain         = \"Cast each item of `res` to type of matching item in `b` if it's a superclass.\",\n",
    "         create_item    = \"Return the item at index `s` of the dataset, or the next item from the dataset iterator if `s` is `None`.\",\n",
    "         create_batch   = \"Collate a list of items into a batch.\",\n",
    "         do_batch       = \"Combines `create_batch` and `before_batch` to get a batch of items. Input is a list of items to collate.\",\n",
    "         to             = \"Sets `self.device=device`.\",\n",
    "         one_batch      = \"Return one batch from `DataLoader`.\",\n",
    "         wif            = \"See pytorch `worker_init_fn` for details (https://pytorch.org/docs/stable/data.html#multi-process-data-loading).\",\n",
    "         before_iter    = \"Called before `DataLoader` starts to read/iterate over the dataset.\",\n",
    "         after_item     = \"Takes output of `create_item` as input and applies this function on it.\",\n",
    "         before_batch   = \"It is called before collating a list of items into a batch. Input is a list of items.\",\n",
    "         after_batch    = \"After collating mini-batch of items, the mini-batch is passed through this function.\",\n",
    "         after_iter     = \"Called after `DataLoader` has fully read/iterated over the dataset.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Arguments to `DataLoader`:\n",
    "* `dataset`: dataset from which to load the data. Can be either map-style or iterable-style dataset.\n",
    "* `bs` (int): how many samples per batch to load (if `batch_size` is provided then `batch_size` will override `bs`). If `bs=None`, then it is assumed that `dataset.__getitem__` returns a batch.\n",
    "* `num_workers` (int): how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.\n",
    "* `pin_memory` (bool): If `True`, the data loader will copy Tensors into CUDA pinned memory before returning them.\n",
    "* `timeout` (float>0): the timeout value in seconds for collecting a batch from workers.\n",
    "* `batch_size` (int): It is only provided for PyTorch compatibility. Use `bs`.\n",
    "* `shuffle` (bool): If `True`, then data is shuffled every time dataloader is fully read/iterated.\n",
    "* `drop_last` (bool): If `True`, then the last incomplete batch is dropped.\n",
    "* `indexed` (bool): Set to `False`, if you are using iterable-style dataset. Otherwise it is set to `True` by default.\n",
    "* `n` (int): Defaults to `len(dataset)`. If you are using iterable-style dataset, you can specify the number of items in the dataset using `n`.\n",
    "* `device` (torch.device): Defaults to `None`, in which case batches are not moved to any device. You can specify a device such as `torch.device('cpu')`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override `get_idxs` to return the same index until consumption of the DL. This is intended to test consistent sampling behavior when `num_workers`>1. Note it does not need to use `self.rng` anymore to maintain consistent behavior across workers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdamantDL(DataLoader):\n",
    "    def get_idxs(self):\n",
    "        # Draw a single position, then repeat it for every one of the `n` slots\n",
    "        fixed = random.randint(0, self.n-1)\n",
    "        return [fixed for _ in range(self.n)]\n",
    "test_eq(torch.cat(tuple(AdamantDL((list(range(50))),bs=16,num_workers=4))).unique().numel(),1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override `create_item` and use the default infinite sampler to get a stream of unknown length (call `stop()` when you want to stop the stream)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#15) [0.0731038973596212,0.24229693635490512,0.3284360867268481,0.5409560460233285,0.30370969993353036,0.1483296960825845,0.39379237978935255,0.49859287529086227,0.5569261363148857,0.6875843672764917...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class RandDL(DataLoader):\n",
    "    def create_item(self, s):\n",
    "        # The sample `s` is ignored: emit random floats, calling `stop()` once a draw reaches 0.95\n",
    "        value = random.random()\n",
    "        if value < 0.95: return value\n",
    "        return stop()\n",
    "\n",
    "L(RandDL())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#0) []"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "L(RandDL(bs=4, drop_last=True)).map(len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#26) [4,4,4,4,4,4,4,4,4,4...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dl = RandDL(bs=4, num_workers=4, drop_last=True)\n",
    "L(dl).map(len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `no_multiproc` sets `num_workers` to 0 inside the block and restores the old value afterwards\n",
    "test_eq(dl.fake_l.num_workers, 4)\n",
    "with dl.fake_l.no_multiproc(): \n",
    "    test_eq(dl.fake_l.num_workers, 0)\n",
    "    L(dl).map(len)\n",
    "test_eq(dl.fake_l.num_workers, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#129) [0.3826905520773186,0.5186695293970484,0.9163681581510872,0.04076529388170569,0.5748628515084315,0.6413833614263649,0.007556592946413199,0.8198059240321355,0.5104851763036267,0.5743053801036588...]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _rand_item(s):\n",
    "    # The sample `s` is ignored: emit random floats, calling `stop()` once a draw reaches 0.95\n",
    "    value = random.random()\n",
    "    if value < 0.95: return value\n",
    "    return stop()\n",
    "\n",
    "L(DataLoader(create_item=_rand_item))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you don't set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With no `bs`, every dataset element is yielded on its own\n",
    "ds1 = DataLoader(letters)\n",
    "test_eq(L(ds1), letters)\n",
    "test_eq(len(ds1), 26)\n",
    "\n",
    "test_shuffled(L(DataLoader(letters, shuffle=True)), letters)\n",
    "\n",
    "# `indexed=False` reads the dataset through its iterator instead of by index\n",
    "ds1 = DataLoader(letters, indexed=False)\n",
    "test_eq(L(ds1), letters)\n",
    "test_eq(len(ds1), 26)\n",
    "\n",
    "t2 = L(tensor([0,1,2]),tensor([3,4,5]))\n",
    "ds2 = DataLoader(t2)\n",
    "test_eq_type(L(ds2), t2)\n",
    "\n",
    "# numpy arrays come back as tensors by default...\n",
    "t3 = L(array([0,1,2]),array([3,4,5]))\n",
    "ds3 = DataLoader(t3)\n",
    "test_eq_type(L(ds3), t3.map(tensor))\n",
    "\n",
    "# ...unless `create_batch` is replaced; `after_iter` has run by the time the loader is exhausted\n",
    "ds4 = DataLoader(t3, create_batch=noop, after_iter=lambda: setattr(t3, 'f', 1))\n",
    "test_eq_type(L(ds4), t3)\n",
    "test_eq(t3.f, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you do set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a single item of a batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def twoepochs(d):\n",
    "    # Two full passes over `d`; each batch is concatenated into one word and words are space-separated\n",
    "    words = []\n",
    "    for _ in range(2):\n",
    "        for batch in d: words.append(''.join(list(batch)))\n",
    "    return ' '.join(words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `drop_last=True` discards the incomplete final batch ('yz')\n",
    "ds1 = DataLoader(letters, bs=4, drop_last=True, num_workers=0)\n",
    "test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')\n",
    "\n",
    "# Batch order is unchanged when several workers are used\n",
    "ds1 = DataLoader(letters,4,num_workers=2)\n",
    "test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')\n",
    "\n",
    "ds1 = DataLoader(range(12), bs=4, num_workers=3)\n",
    "test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))\n",
    "\n",
    "# Strings are batched into lists; `after_iter` has run by the end of the epoch\n",
    "ds1 = DataLoader([str(i) for i in range(11)], bs=4, after_iter=lambda: setattr(t3, 'f', 2))\n",
    "test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))\n",
    "test_eq(t3.f, 2)\n",
    "\n",
    "# A `map` object has no `__getitem__`, so it is consumed through its iterator\n",
    "it = iter(DataLoader(map(noop,range(20)), bs=4, num_workers=1))\n",
    "test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.36 ms, sys: 884 µs, total: 5.24 ms\n",
      "Wall time: 329 ms\n",
      "CPU times: user 11.2 ms, sys: 18.4 ms, total: 29.6 ms\n",
      "Wall time: 172 ms\n",
      "CPU times: user 16.2 ms, sys: 18 ms, total: 34.2 ms\n",
      "Wall time: 112 ms\n"
     ]
    }
   ],
   "source": [
    "class SleepyDL(list):\n",
    "    def __getitem__(self,i):\n",
    "        # Pause for a random interval below 20ms before delegating to `list`\n",
    "        delay = random.random()/50\n",
    "        time.sleep(delay)\n",
    "        item = super().__getitem__(i)\n",
    "        return item\n",
    "\n",
    "t = SleepyDL(letters)\n",
    "\n",
    "# The output must equal `letters` whatever the worker count; only the timing differs\n",
    "%time test_eq(DataLoader(t, num_workers=0), letters)\n",
    "%time test_eq(DataLoader(t, num_workers=2), letters)\n",
    "%time test_eq(DataLoader(t, num_workers=4), letters)\n",
    "\n",
    "# Two passes over a shuffling loader are expected to come out in different orders\n",
    "dl = DataLoader(t, shuffle=True, num_workers=1)\n",
    "test_shuffled(L(dl), letters)\n",
    "test_shuffled(L(dl), L(dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 16.7 ms, sys: 17.3 ms, total: 34 ms\n",
      "Wall time: 81.4 ms\n"
     ]
    }
   ],
   "source": [
    "class SleepyQueue():\n",
    "    \"Simulate a queue with varying latency\"\n",
    "    def __init__(self, q):\n",
    "        self.q = q\n",
    "\n",
    "    def __iter__(self):\n",
    "        # Drain `q` without blocking, pausing below 10ms before each read; finish once it reports empty\n",
    "        while True:\n",
    "            pause = random.random()/100\n",
    "            time.sleep(pause)\n",
    "            try:\n",
    "                yield self.q.get_nowait()\n",
    "            except queues.Empty:\n",
    "                return\n",
    "\n",
    "q = Queue()\n",
    "for o in range(30): q.put(o)\n",
    "it = SleepyQueue(q)\n",
    "\n",
    "# Four workers drain the same queue; the check expects the items of `range(30)` in a different order\n",
    "%time test_shuffled(L(DataLoader(it, num_workers=4)), range(30))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class A(TensorBase): pass\n",
    "\n",
    "# The tensor subclass `A` must survive batching, both in-process (0) and with 2 workers\n",
    "for nw in (0,2):\n",
    "    t = A(tensor([1,2]))\n",
    "    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)\n",
    "    b = first(dl)\n",
    "    test_eq(type(b), A)\n",
    "\n",
    "    # Same check when each item is a 1-tuple holding an `A`\n",
    "    t = (A(tensor([1,2])),)\n",
    "    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)\n",
    "    b = first(dl)\n",
    "    test_eq(type(b[0]), A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(DataLoader(list(range(50)),bs=32,shuffle=True,num_workers=3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class A(TensorBase): pass\n",
    "t = A(tensor(1,2))\n",
    "\n",
    "# The batch type is still `A` when `to_device` is used as the `after_batch` hook\n",
    "tdl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=2, after_batch=to_device)\n",
    "b = first(tdl)\n",
    "test_eq(type(b), A)\n",
    "\n",
    "# Unknown attributes are delegated to `dataset`\n",
    "# NOTE: `pop` removes the last element from the underlying list\n",
    "test_eq(tdl.pop(), tensor(1,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
